Abstract:Dynamic Sparse Training (DST) methods train neural networks by maintaining sparsity while dynamically adapting the network topology. Despite the promise of reduced computation, DST methods converge significantly slower than dense training, often requiring comparable training time to achieve similar accuracy. We demonstrate both analytically and empirically that Batch Normalization (BN) adversely affects sparse training, and propose SparseOpt, a sparsity-aware optimizer, to address this. Experiments on ResNet models across CIFAR-100 and ImageNet demonstrate consistently faster convergence and improved generalization with our proposed method. Our work highlights the limitations of current normalization layers in sparse training and provides the first systematic study of the interaction between Batch Normalization, sparse layers, and DST, taking a significant step toward making DST practically competitive with dense training.
Abstract:Large language model (LLM) development is currently driven by large-scale empirical iteration over data mixtures, reward models, routing strategies, and evaluation pipelines. Here, we argue that many central questions in LLM development and evaluation are inherently causal: What is the effect of adding a data domain during pretraining? How do annotator preferences change when LLMs generate text in a different style? Should a prompt be routed to a larger or smaller model given inference cost constraints? In general, causal methods are well-suited to such settings where interventions change outcomes but, surprisingly, are underrepresented in LLM development. Our contribution is threefold: (1) We explain how causal methods can improve modern LLM development and evaluation: LLM development relies heavily on logged data, which are often subject to confounding and distribution shifts; evaluation uses learned but potentially biased judges; and deployment environments are non-stationary. These conditions make purely predictive approaches fragile and create opportunities for principled identification and estimation methods from causal inference. (2) We further map opportunities for causal methods across the entire LLM development pipeline, including pretraining, alignment, routing, agentic workflows, and evaluation. (3) We discuss new research opportunities around leveraging causal methods for LLM development and evaluation. Overall, we argue that causal methods are potentially underutilized in the LLM development and evaluation pipeline, despite the fact that such methods can ensure a reliable and scientifically grounded design.
Abstract:We introduce CoMET, \textit{\textbf{C}omposing \textbf{M}odality \textbf{E}ncoders with \textbf{T}abular foundation models}, a simple yet highly competitive method for multimodal classification: pass each modality through a frozen pre-trained backbone, compress the resulting embeddings with PCA, and concatenate as input into a Tabular Foundation Model (TFM) for prediction. We show that PCA alone suffices to act as an adaptor yielding strong, robust performance across modalities. When the \texttt{CLS} tokens of the foundation model align poorly with downstream tasks, we propose \textbf{PALPooling}, a lightweight adaptive token pooler that consistently improves representation quality. By composing strong frozen representation learning backbones with TFMs, our approach achieves state-of-the-art results across diverse multimodal benchmarks without any training. On hierarchical tasks with large fine-grained class spaces, our approach enables fast and scalable classification, handling datasets with over 500,000 samples and 2,000 classes without any fine-tuning. Overall, our results show that the composition of foundation models is a simple, yet powerful, out-of-the-box solution for multimodal learning, challenging the necessity of complex, end-to-end training pipelines for new problems.
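The feature-composition step described above (frozen embeddings, PCA compression per modality, concatenation) can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the embedding sizes, the number of components, and the helper names `pca_compress` and `compose_features` are hypothetical, random arrays stand in for backbone outputs, and the tabular foundation model that would consume `features` is omitted.

```python
import numpy as np


def pca_compress(embeddings, n_components):
    """Project embeddings onto their top principal components (via SVD)."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by explained variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


def compose_features(modality_embeddings, n_components):
    """Compress each modality separately, then concatenate column-wise."""
    compressed = [pca_compress(e, n_components) for e in modality_embeddings]
    return np.concatenate(compressed, axis=1)


rng = np.random.default_rng(0)
# Hypothetical stand-ins for the outputs of two frozen pre-trained encoders.
image_emb = rng.normal(size=(32, 768))
text_emb = rng.normal(size=(32, 384))

# One row per sample, ready to be passed to a tabular model as a feature table.
features = compose_features([image_emb, text_emb], n_components=8)
print(features.shape)  # (32, 16)
```

In a real train/test setting the PCA projection would be fitted on training embeddings only and then applied to held-out embeddings; the sketch fits and transforms on the same array for brevity.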
Abstract:Causal inference, estimating causal effects from observational data, is a fundamental tool in many disciplines. Of particular importance across a variety of domains is the continuous treatment setting, where the variable of intervention has a continuous range. This setting is far less explored and represents a substantial shift from the binary treatment setting, with models needing to represent effects across a continuum of treatment values. In this paper, we present the first causal foundation model for the continuous treatment setting. Our model meta-learns the ability to predict causal effects across a wide variety of unseen tasks without additional training or fine-tuning. First, we design a novel prior over data-generating processes with continuous treatment variables in order to generate a rich causal training corpus. We then train a transformer to reconstruct individual treatment-response curves given only observational data, leveraging in-context learning to amortize expensive Bayesian posterior inference. Our model achieves state-of-the-art performance on individual treatment-response curve reconstruction tasks compared to causal models which are trained specifically for those tasks.
Abstract:Irregularly sampled multivariate event streams remain a stubbornly difficult modality for generative modeling: tokenization-based approaches break down when inter-event intervals vary by orders of magnitude, and neural temporal point processes are bottlenecked by window-level numerical quadrature. We propose (i) SurF, a generative model that uses the Time Rescaling Theorem (TRT) as a learnable bijection between event sequences and i.i.d.\ unit-rate exponential noise, enabling a single model to be trained across heterogeneous event-stream datasets; (ii) three efficient parameterizations of the cumulative intensity that scale to long sequences; and (iii) a Transformer-based encoder for multi-dataset pretraining. On six real-world benchmarks, SurF achieves the best reported time RMSE on Earthquake, Retweet, and Taobao, and is within trial-level noise of the strongest specialist on the remaining three. Under a strict leave-one-out protocol, the held-out checkpoint beats every classical and neural-autoregressive baseline on 5/6 datasets and beats every baseline on Amazon and Earthquake, an initial step toward foundation models over asynchronous event streams.
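The Time Rescaling Theorem mentioned above states that if event times are pushed through the cumulative intensity of the process that generated them, the resulting gaps are i.i.d. unit-rate exponential. A minimal sketch of that bijection follows, assuming a hypothetical history-independent intensity lambda(t) = A + B*t chosen only because its cumulative intensity has a closed-form inverse; it is not one of the parameterizations proposed in the abstract.

```python
import math
import random

# Hypothetical intensity lambda(t) = A + B * t (an assumption for illustration).
A, B = 0.5, 0.2


def cum_intensity(t):
    """Cumulative intensity Lambda(t): the integral of A + B*s from 0 to t."""
    return A * t + 0.5 * B * t * t


def inv_cum_intensity(u):
    """Inverse of Lambda: the t >= 0 that solves A*t + B*t^2/2 = u."""
    return (-A + math.sqrt(A * A + 2.0 * B * u)) / B


def events_to_noise(times):
    """Forward map: rescaled inter-event gaps Lambda(t_i) - Lambda(t_{i-1})."""
    rescaled = [cum_intensity(t) for t in times]
    return [cur - prev for prev, cur in zip([0.0] + rescaled[:-1], rescaled)]


def noise_to_events(gaps):
    """Inverse map: accumulate the gaps and pull them back through Lambda^{-1}."""
    times, total = [], 0.0
    for gap in gaps:
        total += gap
        times.append(inv_cum_intensity(total))
    return times


random.seed(0)
noise = [random.expovariate(1.0) for _ in range(5)]  # unit-rate exponential draws
sampled_times = noise_to_events(noise)               # generation: noise -> events
recovered = events_to_noise(sampled_times)           # encoding: events -> noise
```

Sampling amounts to drawing exponential noise and inverting the cumulative intensity, while `events_to_noise` recovers the same noise from the sampled times up to floating-point error.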
Abstract:The instrumental-variables (IV) setting is standard for partial identification of causal effects when unobserved confounding makes point identification impossible. Existing approaches face methodological bottlenecks: closed-form bound estimands are required -- e.g., Balke-Pearl equations in binary IV -- and even when available, designing accurate estimators requires manual effort tailored to each estimand. While direct Bayesian inference of the causal effects, instead of the bounds, circumvents these challenges, it is often computationally intensive and suffers from high prior sensitivity or under-dispersed posteriors. As a remedy, we introduce IV-ICL, an amortized Bayesian in-context learning method that learns the marginal posterior distribution of the causal effects directly and derives bounds as its quantiles. Unlike standard variational inference that optimizes exclusive KL divergence, amortized Bayesian inference minimizes the expected inclusive KL, a mass-covering objective. We empirically observe that optimizing inclusive KL can recover the entire identified set across diverse data-generating processes, while exclusive-KL (e.g. with variational inference) of the same Bayesian formulation collapses onto a single mode and fails to cover the identified set. We evaluate IV-ICL on synthetic and semi-synthetic IV benchmarks and show it produces intervals that are more reliably valid and more informative compared to efficient semi-parametric, Bayesian, and plug-in baselines, at 20-500x lower inference time. Beyond methodology, we propose a procedure to convert randomized controlled trials into IV benchmarks with provably preserved ground-truth causal effects that enables a more realistic evaluation of partial-identification methods.
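The two divergences contrasted above can be written side by side. The notation here is an assumption introduced for illustration (it does not appear in the abstract): $\theta$ denotes the causal effect, $D$ the observed data, $p(\theta \mid D)$ the exact posterior, and $q_\phi$ the learned approximation.

```latex
\begin{align}
\text{exclusive KL (variational inference):}\quad
  & \mathrm{KL}\big(q_\phi(\theta)\,\|\,p(\theta \mid D)\big)
    = \mathbb{E}_{q_\phi(\theta)}\big[\log q_\phi(\theta) - \log p(\theta \mid D)\big], \\
\text{expected inclusive KL (amortized inference):}\quad
  & \mathbb{E}_{p(D)}\Big[\mathrm{KL}\big(p(\theta \mid D)\,\|\,q_\phi(\theta \mid D)\big)\Big]
    = -\,\mathbb{E}_{p(\theta, D)}\big[\log q_\phi(\theta \mid D)\big] + \text{const}.
\end{align}
```

The exclusive form takes its expectation under $q_\phi$, so it can ignore regions where the posterior has mass. The inclusive form takes its expectation under the true joint, which penalizes $q_\phi$ wherever it assigns low density to plausible values of $\theta$; this is the mass-covering behaviour the abstract refers to. The constant collects terms that do not depend on $\phi$, so the objective can be estimated from simulated pairs $(\theta, D)$ alone.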
Abstract:Training machine learning models requires the storage of large datasets, which often contain sensitive or private data. Storing data is associated with a number of potential risks which increase over time, such as database breaches and malicious adversaries. Machine unlearning is the study of methods to efficiently remove the influence of training data subsets from previously-trained models. Existing unlearning methods typically require direct access to the "forget set" -- the data to be forgotten -- and organisations must retain this data for unlearning rather than deleting it immediately upon request, increasing risks associated with the forget set. We introduce partially-blind unlearning -- utilizing auxiliary information to unlearn without explicit access to the forget set. We also propose Reload, a practical framework that operationalizes partially-blind unlearning through gradient optimization and structured weight sparsification. We show that Reload efficiently unlearns, approximating models retrained from scratch, and outperforms several forget set-dependent approaches. On language models, Reload unlearns entities using <0.025% of the retain set and <7% of model weights in <8 minutes on Llama2-7B. In the corrective case, Reload achieves unlearning even when only 10% of corrupted data is identified.
Abstract:Deploying clinical ML is slow and brittle: models that work at one hospital often degrade under distribution shifts at the next. In this work, we study a simple question -- can large language models (LLMs) create portable patient embeddings, i.e., representations of patients that enable a downstream predictor built at one hospital to be used elsewhere with minimal-to-no retraining and fine-tuning? To do so, we map irregular ICU time series onto concise natural language summaries using a frozen LLM, then embed each summary with a frozen text embedding model to obtain a fixed-length vector capable of serving as input to a variety of downstream predictors. Across three cohorts (MIMIC-IV, HIRID, PPICU), on multiple clinically grounded forecasting and classification tasks, we find that our approach is simple, easy to use, and competitive in-distribution with grid imputation, self-supervised representation learning, and time series foundation models, while exhibiting smaller relative performance drops when transferring to new hospitals. We study the variation in performance across prompt designs, with structured prompts being crucial to reducing the variance of the predictive models without altering mean accuracy. We find that using these portable representations improves few-shot learning and does not increase demographic recoverability of age or sex relative to baselines, suggesting little additional privacy risk. Our work points to the potential that LLMs hold as tools to enable the scalable deployment of production-grade predictive models by reducing the engineering overhead.
Abstract:Evidence-based medicine (EBM) is central to high-quality care, but remains difficult to implement in fast-paced primary care settings. Physicians face short consultations, increasing patient loads, and lengthy guideline documents that are impractical to consult in real time. To address this gap, we investigate the feasibility of using large language models (LLMs) as ambient assistants that surface targeted, evidence-based questions during physician-patient encounters. Our study focuses on question generation rather than question answering, with the aim of scaffolding physician reasoning and integrating guideline-based practice into brief consultations. We implemented two prompting strategies, a zero-shot baseline and a multi-stage reasoning variant, using Gemini 2.5 as the backbone model. We evaluated on a benchmark of 80 de-identified transcripts from real clinical encounters, with six experienced physicians contributing over 90 hours of structured review. Results indicate that while general-purpose LLMs are not yet fully reliable, they can produce clinically meaningful and guideline-relevant questions, suggesting significant potential to reduce cognitive burden and make EBM more actionable at the point of care.
Abstract:Masked diffusion models (MDM) exhibit superior generalization when learned using a Partial masking scheme (Prime). This approach converts tokens into sub-tokens and models the diffusion process at the sub-token level. We identify two limitations of the MDM-Prime framework. First, we lack tools to guide the hyperparameter choice of the token granularity in the subtokenizer. Second, we find that the functional form of the subtokenizer significantly degrades likelihood estimation when paired with commonly used Byte-Pair-Encoding (BPE) tokenizers. To address these limitations, we study the tightness of the variational bound in MDM-Prime and develop MDM-Prime-v2, a masked diffusion language model which incorporates Binary Encoding and Index Shuffling. Our scaling analysis reveals that MDM-Prime-v2 is 21.8$\times$ more compute-efficient than autoregressive models (ARM). In compute-optimal comparisons, MDM-Prime-v2 achieves 7.77 perplexity on OpenWebText, outperforming ARM (12.99), MDM (18.94), and MDM-Prime (13.41). When extending the model size to 1.1B parameters, our model further demonstrates superior zero-shot accuracy on various commonsense reasoning tasks.
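The idea of converting tokens into sub-tokens can be illustrated with a toy invertible subtokenizer. The sketch below is a guess at the general shape of such a mapping, not the construction used in MDM-Prime-v2: it assumes that binary encoding means writing a token index as a fixed number of base-2 digits and that index shuffling means applying a fixed random permutation to vocabulary indices before encoding. The vocabulary size, the seed, and all function names are hypothetical.

```python
import random


def to_subtokens(token_id, base, length):
    """Write token_id as `length` base-`base` digits, most significant first."""
    digits = []
    for _ in range(length):
        digits.append(token_id % base)
        token_id //= base
    if token_id != 0:
        raise ValueError("token_id does not fit in the requested number of sub-tokens")
    return digits[::-1]


def from_subtokens(digits, base):
    """Inverse of to_subtokens: fold the digits back into one integer."""
    token_id = 0
    for digit in digits:
        token_id = token_id * base + digit
    return token_id


# Hypothetical toy vocabulary of 16 tokens, so 4 binary sub-tokens per token.
VOCAB_SIZE, BASE, LENGTH = 16, 2, 4

# Assumed form of index shuffling: a fixed random permutation of token indices.
shuffle = list(range(VOCAB_SIZE))
random.Random(0).shuffle(shuffle)
unshuffle = [0] * VOCAB_SIZE
for original, shuffled in enumerate(shuffle):
    unshuffle[shuffled] = original


def encode(token_id):
    """Token -> shuffled index -> binary sub-tokens."""
    return to_subtokens(shuffle[token_id], BASE, LENGTH)


def decode(sub_tokens):
    """Binary sub-tokens -> shuffled index -> original token."""
    return unshuffle[from_subtokens(sub_tokens, BASE)]
```

Because both the permutation and the digit expansion are invertible, `decode(encode(t))` returns `t` for every token in the toy vocabulary, so no information is lost by modeling the sequence at the sub-token level.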